# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Canonicalizes functions with multiple returns to use just one."""

import gast

from nvidia.dali._autograph.core import converter
from nvidia.dali._autograph.pyct import anno
from nvidia.dali._autograph.pyct import parser
from nvidia.dali._autograph.pyct import qual_names
from nvidia.dali._autograph.pyct import templates
from nvidia.dali._autograph.pyct.static_analysis import activity
from nvidia.dali._autograph.pyct.static_analysis.annos import NodeAnno


BODY_DEFINITELY_RETURNS = "BODY_DEFINITELY_RETURNS"
ORELSE_DEFINITELY_RETURNS = "ORELSE_DEFINITELY_RETURNS"
STMT_DEFINITELY_RETURNS = "STMT_DEFINITELY_RETURNS"


class _RewriteBlock(object):
    def __init__(self):
        self.definitely_returns = False


class ConditionalReturnRewriter(converter.Base):
    """Rewrites a pattern where it's unobvious that all paths return a value.

    This rewrite allows avoiding intermediate None return values.

    The following pattern:

        if cond:
          <block 1>
          return
        else:
          <block 2>
        <block 3>

    is converted to:

        if cond:
          <block 1>
          return
        else:
          <block 2>
          <block 3>

    and vice-versa (if the else returns, subsequent statements are moved under the
    if branch).
    """

    def visit_Return(self, node):
        self.state[_RewriteBlock].definitely_returns = True
        return node

    def _postprocess_statement(self, node):
        # If the node definitely returns (e.g. it's a with statement with a
        # return statement in it), then the current block also definitely returns.
        if anno.getanno(node, STMT_DEFINITELY_RETURNS, default=False):
            self.state[_RewriteBlock].definitely_returns = True

        # The special case: collapse a typical conditional return pattern into
        # a single conditional with possibly returns on both branches. This
        # reduces the use of None return values, which don't work with TF
        # conditionals.
        if isinstance(node, gast.If) and anno.getanno(node, BODY_DEFINITELY_RETURNS, default=False):
            return node, node.orelse
        elif isinstance(node, gast.If) and anno.getanno(
            node, ORELSE_DEFINITELY_RETURNS, default=False
        ):
            return node, node.body

        return node, None

    def _visit_statement_block(self, node, nodes):
        self.state[_RewriteBlock].enter()
        new_nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)
        block_definitely_returns = self.state[_RewriteBlock].definitely_returns
        self.state[_RewriteBlock].exit()
        return new_nodes, block_definitely_returns

    def visit_While(self, node):
        node.test = self.visit(node.test)
        node.body, _ = self._visit_statement_block(node, node.body)
        node.orelse, _ = self._visit_statement_block(node, node.orelse)
        return node

    def visit_For(self, node):
        node.iter = self.visit(node.iter)
        node.target = self.visit(node.target)
        node.body, _ = self._visit_statement_block(node, node.body)
        node.orelse, _ = self._visit_statement_block(node, node.orelse)
        return node

    def visit_With(self, node):
        node.items = self.visit_block(node.items)
        node.body, definitely_returns = self._visit_statement_block(node, node.body)
        if definitely_returns:
            anno.setanno(node, STMT_DEFINITELY_RETURNS, True)
        return node

    def visit_Try(self, node):
        # We could decide whether a 'try' DEFINITELY_RETURNS based on its components
        # It is not clear whether we want to do anything with this given
        # a 'try' is likely to throw an exception in some circumstances.
        node.body, _ = self._visit_statement_block(node, node.body)
        node.orelse, _ = self._visit_statement_block(node, node.orelse)
        node.finalbody, _ = self._visit_statement_block(node, node.finalbody)
        node.handlers = self.visit_block(node.handlers)
        return node

    def visit_ExceptHandler(self, node):
        # To determine whether `try` DEFINITELY_RETURNS we need to revisit this.
        node.body, _ = self._visit_statement_block(node, node.body)
        return node

    def visit_If(self, node):
        node.test = self.visit(node.test)

        node.body, body_definitely_returns = self._visit_statement_block(node, node.body)
        if body_definitely_returns:
            anno.setanno(node, BODY_DEFINITELY_RETURNS, True)

        node.orelse, orelse_definitely_returns = self._visit_statement_block(node, node.orelse)
        if orelse_definitely_returns:
            anno.setanno(node, ORELSE_DEFINITELY_RETURNS, True)

        if body_definitely_returns and orelse_definitely_returns:
            self.state[_RewriteBlock].definitely_returns = True

        return node

    def visit_FunctionDef(self, node):
        node.args = self.visit(node.args)
        node.body, _ = self._visit_statement_block(node, node.body)
        return node


class _Block(object):
    def __init__(self):
        self.is_function = False
        self.return_used = False
        self.create_guard_next = False
        self.create_guard_now = False

    def __repr__(self):
        return "used: {}".format(self.return_used)


class _Function(object):
    def __init__(self):
        self.do_return_var_name = None
        self.retval_var_name = None

    def __repr__(self):
        return "return control: {}, return value: {}".format(
            self.do_return_var_name, self.retval_var_name
        )


class ReturnStatementsTransformer(converter.Base):
    """Lowers return statements into variables and conditionals.

    Specifically, the following pattern:

        <block 1>
        return val
        <block 2>

    is converted to:

        do_return = False
        retval = None

        <block 1>

        do_return = True
        retval = val

        if not do_return:
          <block 2>

        return retval

    The conversion adjusts loops as well:

        <block 1>
        while cond:
          <block 2>
          return retval

    is converted to:

        <block 1>
        while not do_return and cond:
          <block 2>
          do_return = True
          retval = val
    """

    def __init__(self, ctx, allow_missing_return):
        super(ReturnStatementsTransformer, self).__init__(ctx)
        self.allow_missing_return = allow_missing_return

    def visit_Return(self, node):
        for block in reversed(self.state[_Block].stack):
            block.return_used = True
            block.create_guard_next = True
            if block.is_function:
                break

        retval = node.value if node.value else parser.parse_expression("None")

        # Note: If `return <expr> raises, then the return is aborted.
        # The try-catch below ensures the variables remain consistent in that case.
        template = """
      try:
        do_return_var_name = True
        retval_var_name = retval
      except:
        do_return_var_name = False
        raise
    """
        node = templates.replace(
            template,
            do_return_var_name=self.state[_Function].do_return_var_name,
            retval_var_name=self.state[_Function].retval_var_name,
            retval=retval,
        )

        return node

    def _postprocess_statement(self, node):
        if not self.state[_Block].return_used:
            return node, None

        state = self.state[_Block]
        if state.create_guard_now:
            template = """
        if not do_return_var_name:
          original_node
      """
            (cond,) = templates.replace(
                template,
                do_return_var_name=self.state[_Function].do_return_var_name,
                original_node=node,
            )
            node, block = cond, cond.body
        else:
            node, block = node, None

        state.create_guard_now = state.create_guard_next
        state.create_guard_next = False

        return node, block

    def _visit_statement_block(self, node, nodes):
        self.state[_Block].enter()
        nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)
        self.state[_Block].exit()
        return nodes

    def visit_While(self, node):
        node.test = self.visit(node.test)

        # Add the check for return to the loop condition.
        node.body = self._visit_statement_block(node, node.body)
        if self.state[_Block].return_used:
            node.test = templates.replace_as_expression(
                "not control_var and test",
                test=node.test,
                control_var=self.state[_Function].do_return_var_name,
            )

        node.orelse = self._visit_statement_block(node, node.orelse)
        return node

    def visit_For(self, node):
        node.iter = self.visit(node.iter)
        node.target = self.visit(node.target)

        # Add the check for return to the loop condition.
        node.body = self._visit_statement_block(node, node.body)
        if self.state[_Block].return_used:
            extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
            if extra_test is not None:
                extra_test = templates.replace_as_expression(
                    "not control_var and extra_test",
                    extra_test=extra_test,
                    control_var=self.state[_Function].do_return_var_name,
                )
            else:
                extra_test = templates.replace_as_expression(
                    "not control_var", control_var=self.state[_Function].do_return_var_name
                )
            anno.setanno(node, anno.Basic.EXTRA_LOOP_TEST, extra_test)

        node.orelse = self._visit_statement_block(node, node.orelse)
        return node

    def visit_With(self, node):
        node.items = self.visit_block(node.items)
        node.body = self._visit_statement_block(node, node.body)
        return node

    def visit_Try(self, node):
        node.body = self._visit_statement_block(node, node.body)
        node.orelse = self._visit_statement_block(node, node.orelse)
        node.finalbody = self._visit_statement_block(node, node.finalbody)
        node.handlers = self.visit_block(node.handlers)
        return node

    def visit_ExceptHandler(self, node):
        node.body = self._visit_statement_block(node, node.body)
        return node

    def visit_If(self, node):
        node.test = self.visit(node.test)
        node.body = self._visit_statement_block(node, node.body)
        node.orelse = self._visit_statement_block(node, node.orelse)
        return node

    def visit_FunctionDef(self, node):
        with self.state[_Function] as fn:
            with self.state[_Block] as block:
                block.is_function = True

                scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
                do_return_var_name = self.ctx.namer.new_symbol("do_return", scope.referenced)
                retval_var_name = self.ctx.namer.new_symbol("retval_", scope.referenced)
                fn.do_return_var_name = do_return_var_name
                fn.retval_var_name = retval_var_name

                node.body = self._visit_statement_block(node, node.body)

                if block.return_used:
                    if self.allow_missing_return:
                        # The function would have a single `with` node that wraps the
                        # entire body. If the function had a docstring, the body has two
                        # nodes, with the `with` as the second node.
                        wrapper_node = node.body[-1]
                        assert isinstance(
                            wrapper_node, gast.With
                        ), "This transformer requires the functions converter."

                        template = """
              do_return_var_name = False
              retval_var_name = ag__.UndefinedReturnValue()
              body
              return function_context.ret(retval_var_name, do_return_var_name)
            """

                        wrapper_node.body = templates.replace(
                            template,
                            body=wrapper_node.body,
                            do_return_var_name=do_return_var_name,
                            function_context=anno.getanno(node, "function_context_name"),
                            retval_var_name=retval_var_name,
                        )
                    else:
                        template = """
              body
              return retval_var_name
            """
                        node.body = templates.replace(
                            template,
                            body=node.body,
                            do_return_var_name=do_return_var_name,
                            retval_var_name=retval_var_name,
                        )

        return node


def transform(node, ctx, default_to_null_return=True):
    """Ensure a function has only a single return, at the end."""
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx, None)

    # Note: Technically, these two could be merged into a single walk, but
    # keeping them separate helps with readability.
    node = ConditionalReturnRewriter(ctx).visit(node)

    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx, None)
    transformer = ReturnStatementsTransformer(ctx, allow_missing_return=default_to_null_return)
    node = transformer.visit(node)
    return node
